/*
 * SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "dropout.cuh"
#include "softmax.cuh"
#include <strided_batched_gemm.cuh>
#include <torch/extension.h>

#include <cmath>
#include <cstdint>
#include <limits>

/**
 * Fused self-attention forward pass on packed, interleaved QKV activations.
 *
 * @param qkv          float16 CUDA tensor of shape [s, b, h, 3, d] (sequence length, sequences,
 *                     heads, {Q, K, V}, head size). A non-contiguous input is made dense first.
 * @param pad_mask     float16 CUDA mask handed to the fused masked-softmax/dropout kernel.
 * @param dropout_prob dropout probability; must be in [0, 1) (1 - dropout_prob is the keep probability).
 *
 * @return {bmm1_results, dropout_results, dropout_mask, matmul2_results}:
 *         - bmm1_results    [h*b, s, s] scaled Q x K^T scores (input of the softmax),
 *         - dropout_results [h*b, s, s] output of the fused masked softmax + dropout,
 *         - dropout_mask    [h*b, s, s] uint8 mask written by the fused kernel,
 *         - matmul2_results [s, h*b, d] attention output.
 *
 * Work is enqueued on the current CUDA stream; no explicit synchronization is performed here.
 */
std::vector<torch::Tensor> fwd_cuda(torch::Tensor const& qkv, torch::Tensor const& pad_mask, float dropout_prob)
{
    TORCH_CHECK(qkv.dim() == 5 && qkv.size(3) == 3, "fwd_cuda: qkv must have shape [s, b, h, 3, d], got ",
        qkv.sizes());
    TORCH_CHECK(qkv.is_cuda() && pad_mask.is_cuda(), "fwd_cuda: qkv and pad_mask must be CUDA tensors");
    TORCH_CHECK(qkv.scalar_type() == at::kHalf, "fwd_cuda: qkv must be float16, got ", qkv.scalar_type());
    TORCH_CHECK(
        pad_mask.scalar_type() == at::kHalf, "fwd_cuda: pad_mask must be float16, got ", pad_mask.scalar_type());
    TORCH_CHECK(dropout_prob >= 0.0f && dropout_prob < 1.0f, "fwd_cuda: dropout_prob must be in [0, 1), got ",
        dropout_prob);

    // The sizes, strides and element counts below are computed in 32-bit int; reject inputs that would overflow.
    int64_t const attn_batches64 = qkv.size(2) * qkv.size(1);
    int64_t const int_max = std::numeric_limits<int>::max();
    TORCH_CHECK(attn_batches64 * qkv.size(0) * qkv.size(0) <= int_max && attn_batches64 * 3 * qkv.size(4) <= int_max,
        "fwd_cuda: input too large for 32-bit index arithmetic, qkv sizes ", qkv.sizes());

    // Dense copy (no-op when already contiguous): the GEMMs address Q, K and V through fixed strides
    // derived from the sizes, which only match a contiguous [s, b, h, 3, d] layout.
    torch::Tensor const qkv_dense = qkv.contiguous();

    // Layout: [s, b, h, 3, d]
    int const head_dim = static_cast<int>(qkv_dense.size(4));
    int const heads = static_cast<int>(qkv_dense.size(2));
    int const sequences = static_cast<int>(qkv_dense.size(1));
    int const q_seq_len = static_cast<int>(qkv_dense.size(0));
    int const k_seq_len = q_seq_len; // self-attention: keys come from the same sequence as queries
    int const attn_batches = heads * sequences;
    // Distance (in elements) between consecutive time steps of the same Q/K/V slice.
    int const lead_dim = attn_batches * 3 * head_dim;
    // Distance (in elements) between consecutive (sequence, head) batches.
    int const batch_stride = 3 * head_dim;
    float const alpha = 1.0f;
    float const beta_zero = 0.0f;
    float const scale = 1.0f / std::sqrt(static_cast<float>(head_dim));

    cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();

    auto act_options = qkv_dense.options().requires_grad(false);
    auto mask_options = act_options.dtype(torch::kUInt8);

    // Intermediate results returned for the backward pass, plus the output. dropout_results and
    // dropout_mask are filled by the fused softmax/dropout dispatch below.
    torch::Tensor bmm1_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
    torch::Tensor dropout_results = torch::empty({attn_batches, q_seq_len, k_seq_len}, act_options);
    torch::Tensor dropout_mask = torch::empty({attn_batches, q_seq_len, k_seq_len}, mask_options);
    torch::Tensor matmul2_results = torch::empty({q_seq_len, attn_batches, head_dim}, act_options);

    if (bmm1_results.numel() == 0)
    {
        // No sequences, heads or time steps: nothing to compute. Returning early also avoids
        // launching GEMMs with zero leading dimensions.
        return {bmm1_results, dropout_results, dropout_mask, matmul2_results};
    }

    // Q, K and V are interleaved per head: K starts head_dim elements after Q, V 2 * head_dim after Q.
    half const* qkv_ptr = static_cast<half const*>(qkv_dense.data_ptr());
    half const* q_lin_results_ptr = qkv_ptr;
    half const* k_lin_results_ptr = qkv_ptr + head_dim;
    half const* v_lin_results_ptr = qkv_ptr + 2 * head_dim;

    half* bmm1_results_ptr = static_cast<half*>(bmm1_results.data_ptr());
    half* dropout_results_ptr = static_cast<half*>(dropout_results.data_ptr());

    char a_layout_t{'t'};
    char a_layout_n{'n'};
    char b_layout_n{'n'};

    // MatMul1: bmm1 = scale * Q x K^T, with scale = 1 / sqrt(head_dim).
    gemm_switch_fp32accum(a_layout_t, b_layout_n, k_seq_len, q_seq_len, head_dim, scale, k_lin_results_ptr, lead_dim,
        batch_stride, q_lin_results_ptr, lead_dim, batch_stride, beta_zero, bmm1_results_ptr, k_seq_len,
        k_seq_len * q_seq_len, attn_batches);

    // Fused padded softmax + dropout over rows of k_seq_len scores. The pad-mask batch stride is
    // heads * q_seq_len rows (== attn_batches * q_seq_len / sequences, without the division).
    bool const softmax_success = dispatch_additive_masked_softmax_dropout<half, half, float>(dropout_results_ptr,
        dropout_mask.data_ptr<uint8_t>(), bmm1_results_ptr, static_cast<half*>(pad_mask.data_ptr()),
        attn_batches * q_seq_len * k_seq_len, k_seq_len, k_seq_len, attn_batches * q_seq_len, heads * q_seq_len,
        1.0f - dropout_prob, stream);
    TORCH_CHECK(
        softmax_success, "fwd_cuda: fused masked softmax/dropout dispatch failed for k_seq_len=", k_seq_len);

    // MatMul2: output = dropout(softmax) x V
    gemm_switch_fp32accum(a_layout_n, b_layout_n, head_dim, q_seq_len, k_seq_len, alpha, v_lin_results_ptr, lead_dim,
        batch_stride, static_cast<half*>(dropout_results.data_ptr()), k_seq_len, k_seq_len * q_seq_len, beta_zero,
        static_cast<half*>(matmul2_results.data_ptr()), head_dim * attn_batches, head_dim, attn_batches);

    return {bmm1_results, dropout_results, dropout_mask, matmul2_results};
}

/**
 * Backward pass of fwd_cuda: gradients with respect to the packed QKV activations.
 *
 * @param heads             number of attention heads; must be positive.
 * @param output_lin_grads  float16 gradient w.r.t. fwd_cuda's matmul2_results ([s, h*b, d]);
 *                          made dense if needed.
 * @param matmul2_results   forward output; kept for interface compatibility, not read here.
 * @param dropout_results   float16 softmax+dropout output of fwd_cuda ([h*b, s, s]).
 * @param bmm1_results      float16 scaled Q x K^T scores of fwd_cuda ([h*b, s, s]); the softmax
 *                          backward recomputes from these.
 * @param pad_mask          float16 mask that was passed to fwd_cuda.
 * @param input_lin_results float16 packed QKV activations given to fwd_cuda; made dense if needed.
 * @param dropout_mask      dropout mask returned by fwd_cuda.
 * @param dropout_prob      dropout probability used in the forward pass; must be in [0, 1).
 *
 * @return {input_lin_output_grads, matmul2_grads}: the gradient w.r.t. the QKV activations (same
 *         sizes as input_lin_results) and the buffer holding dP, the gradient fed to MatMul1's dgrads.
 */
std::vector<torch::Tensor> bwd_cuda(int heads,
    torch::Tensor const& output_lin_grads,  // d_out
    torch::Tensor const& matmul2_results,   // forward output (unused)
    torch::Tensor const& dropout_results,   // dropout(softmax(...)) from the forward pass
    torch::Tensor const& bmm1_results,      // scaled Q x K^T from the forward pass
    torch::Tensor const& pad_mask,          // attention mask
    torch::Tensor const& input_lin_results, // qkv
    torch::Tensor const& dropout_mask,      // dropout mask from the forward pass
    float dropout_prob)
{
    TORCH_CHECK(heads > 0, "bwd_cuda: heads must be positive, got ", heads);
    TORCH_CHECK(dropout_prob >= 0.0f && dropout_prob < 1.0f, "bwd_cuda: dropout_prob must be in [0, 1), got ",
        dropout_prob);
    auto const check_half = [](torch::Tensor const& t, char const* name) {
        TORCH_CHECK(t.scalar_type() == at::kHalf, "bwd_cuda: ", name, " must be float16, got ", t.scalar_type());
    };
    check_half(output_lin_grads, "output_lin_grads");
    check_half(dropout_results, "dropout_results");
    check_half(bmm1_results, "bmm1_results");
    check_half(pad_mask, "pad_mask");
    check_half(input_lin_results, "input_lin_results");

    // Dense copies (no-ops when already contiguous): the GEMMs below use fixed strides derived from the sizes.
    torch::Tensor const qkv = input_lin_results.contiguous();
    torch::Tensor const d_out = output_lin_grads.contiguous();

    // The sizes, strides and element counts below are computed in 32-bit int; reject inputs that would overflow.
    int64_t const attn_batches64 = static_cast<int64_t>(heads) * qkv.size(1);
    int64_t const int_max = std::numeric_limits<int>::max();
    TORCH_CHECK(attn_batches64 * qkv.size(0) * qkv.size(0) <= int_max && attn_batches64 * 3 * qkv.size(-1) <= int_max,
        "bwd_cuda: input too large for 32-bit index arithmetic, input_lin_results sizes ", qkv.sizes());

    // Layout as in fwd_cuda: [s, b, ..., d]; the last dimension is the head size.
    int const head_dim = static_cast<int>(qkv.size(-1));
    int const sequences = static_cast<int>(qkv.size(1));
    int const q_seq_len = static_cast<int>(qkv.size(0));
    int const k_seq_len = q_seq_len;

    if (sequences == 0 || q_seq_len == 0)
    {
        // Empty batch: nothing to compute. Returning early also avoids GEMMs with zero leading dimensions.
        return {torch::zeros_like(input_lin_results), torch::zeros_like(dropout_results)};
    }

    int const attn_batches = heads * sequences;
    int const lead_dim = attn_batches * 3 * head_dim;
    int const batch_stride = 3 * head_dim;
    float const alpha = 1.0f;
    float const beta = 0.0f;
    float const scale = 1.0f / std::sqrt(static_cast<float>(head_dim));
    // Rescaling factor 1 / keep_probability applied to the kept entries in the softmax backward.
    float const dropout_scale = 1.0f / (1.0f - dropout_prob);

    // Only one stream is used; the backward GEMMs could in principle be spread over several.
    cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
    cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
    TORCH_CUDABLAS_CHECK(cublasSetStream(handle, stream));

    // Intermediate Tensor Allocations
    at::Tensor matmul2_grads = torch::empty_like(dropout_results);
    at::Tensor input_lin_output_grads = torch::empty_like(qkv);

    // Q, K and V (and their gradients) are interleaved per head: K at +head_dim, V at +2 * head_dim.
    half* const qkv_ptr = static_cast<half*>(qkv.data_ptr());
    half* const q_lin_results_ptr = qkv_ptr;
    half* const k_lin_results_ptr = qkv_ptr + head_dim;
    half* const v_lin_results_ptr = qkv_ptr + 2 * head_dim;

    half* const qkv_grads_ptr = static_cast<half*>(input_lin_output_grads.data_ptr());
    half* const q_lin_grads_ptr = qkv_grads_ptr;
    half* const k_lin_grads_ptr = qkv_grads_ptr + head_dim;
    half* const v_lin_grads_ptr = qkv_grads_ptr + 2 * head_dim;

    half const* const d_out_ptr = static_cast<half const*>(d_out.data_ptr());
    half* const matmul2_grads_ptr = static_cast<half*>(matmul2_grads.data_ptr());

    char a_layout_n{'n'};
    char a_layout_t{'t'};
    char b_layout_n{'n'};
    char b_layout_t{'t'};

    // Use tensor-op math for the GEMMs below and restore the handle's previous mode on scope exit,
    // including when a GEMM throws. The restore is best effort because a destructor must not throw.
    struct MathModeRestorer
    {
        cublasHandle_t handle;
        cublasMath_t mode;
        ~MathModeRestorer()
        {
            (void) cublasSetMathMode(handle, mode);
        }
    };
    cublasMath_t previous_math_mode{};
    TORCH_CUDABLAS_CHECK(cublasGetMathMode(handle, &previous_math_mode));
    TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
    MathModeRestorer const restore_math_mode{handle, previous_math_mode};

    // MatMul2 Dgrad1          d_out x V'
    gemm_switch_fp32accum(a_layout_t, b_layout_n, k_seq_len, q_seq_len, head_dim, alpha,
        static_cast<half const*>(v_lin_results_ptr), lead_dim, batch_stride, d_out_ptr, head_dim * attn_batches,
        head_dim, beta, matmul2_grads_ptr, k_seq_len, k_seq_len * q_seq_len, attn_batches);

    // MatMul2 Dgrad2          dV = (S * D)' x d_out
    gemm_switch_fp32accum(a_layout_n, b_layout_t, head_dim, k_seq_len, q_seq_len, alpha, d_out_ptr,
        head_dim * attn_batches, head_dim, static_cast<half const*>(dropout_results.data_ptr()), k_seq_len,
        k_seq_len * q_seq_len, beta, v_lin_grads_ptr, lead_dim, batch_stride, attn_batches);

    // Apply the dropout mask, rescale by 1 / (1 - dropout_prob) and run the softmax backward in place
    // on matmul2_grads (recomputing the softmax from bmm1_results), producing dP.
    // The pad-mask batch stride is heads * q_seq_len rows (== attn_batches * q_seq_len / sequences).
    dispatch_masked_scale_softmax_backward_recompute<half, half, float, false>(matmul2_grads_ptr, matmul2_grads_ptr,
        reinterpret_cast<half const*>(bmm1_results.data_ptr()), reinterpret_cast<half const*>(pad_mask.data_ptr()),
        static_cast<uint8_t const*>(dropout_mask.data_ptr()), dropout_scale, k_seq_len, k_seq_len,
        heads * q_seq_len, attn_batches * q_seq_len, stream);

    // MatMul1 Dgrad1          dQ = scale * dP x K
    gemm_switch_fp32accum(a_layout_n, b_layout_n, head_dim, q_seq_len, k_seq_len, scale, k_lin_results_ptr, lead_dim,
        batch_stride, matmul2_grads_ptr, k_seq_len, k_seq_len * q_seq_len, beta, q_lin_grads_ptr, lead_dim,
        batch_stride, attn_batches);

    // MatMul1 Dgrad2          dK = scale * dP' x Q
    gemm_switch_fp32accum(a_layout_n, b_layout_t, head_dim, k_seq_len, q_seq_len, scale, q_lin_results_ptr, lead_dim,
        batch_stride, matmul2_grads_ptr, k_seq_len, k_seq_len * q_seq_len, beta, k_lin_grads_ptr, lead_dim,
        batch_stride, attn_batches);

    return {input_lin_output_grads, matmul2_grads};
}

// Python bindings: exposes fwd_cuda as `fwd` and bwd_cuda as `bwd`.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.doc() = "CUDA fused Multihead-Attention for BERT"; // optional module docstring
    m.def("fwd", &fwd_cuda, "Forward pass");
    // The docstring previously (incorrectly) read "Forward pass" for the backward binding.
    m.def("bwd", &bwd_cuda, "Backward pass");
}
